import os
import random
import time

from tools.utils import write_json, read_json


class DeviceManager:
    """Reserve a device slot through a JSON log file and expose it via CUDA_VISIBLE_DEVICES.

    The log file (``self.device_log``) holds a JSON list of the slot ids that
    are currently reserved. ``require`` polls that list until a slot from
    ``self.device_total`` is free, records it, and sets the
    ``CUDA_VISIBLE_DEVICES`` environment variable; ``release`` removes the
    slot from the list again. The instance is also a context manager.

    If ``default_device`` is anything other than the string ``'none'``, the
    log file is bypassed entirely and that value is exported directly.

    NOTE(review): the read-modify-write on the log file is not guarded by any
    lock in this class, so two processes polling at the same moment could
    reserve the same slot — confirm whether ``read_json``/``write_json``
    provide any protection.
    """

    def __init__(self, default_device='none'):
        """Create a manager.

        Args:
            default_device: Device to export unconditionally, or ``'none'``
                (the default) to reserve a slot through the log file.
        """
        self.device_log = 'log/device_used.json'
        self.default_device = default_device
        # Slot currently held by this instance; None while nothing is held.
        self.available_device = None
        # Slot ids this manager may hand out.
        self.device_total = list(range(8))

    def require(self):
        """Reserve a device and export it through ``CUDA_VISIBLE_DEVICES``.

        Blocks (polling with a random sub-second sleep) until a slot listed in
        ``self.device_total`` is absent from the log file. Calling this again
        while a slot is already held simply re-exports the held slot.
        """
        if self.default_device != 'none':
            # str() so that an integer device index is accepted as well;
            # os.environ only takes strings.
            os.environ['CUDA_VISIBLE_DEVICES'] = str(self.default_device)
            return

        if self.available_device is not None:
            # Already holding a slot: do not reserve a second one.
            os.environ['CUDA_VISIBLE_DEVICES'] = str(self.available_device % 8)
            return

        if not os.path.exists(self.device_log):
            log_dir = os.path.dirname(self.device_log)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            write_json(self.device_log, [])

        while True:
            # Random delay before every read so that concurrent pollers are
            # less likely to read the log at the same instant.
            time.sleep(random.random())
            device_used = read_json(self.device_log)
            free_slots = [slot for slot in self.device_total
                          if slot not in device_used]
            if free_slots:
                # The last free entry of device_total is the one taken.
                self.available_device = free_slots[-1]
                break

        device_used.append(self.available_device)
        write_json(self.device_log, device_used)
        # NOTE(review): the modulo presumably maps slot ids onto 8 physical
        # devices — confirm before changing.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(self.available_device % 8)

    def release(self):
        """Give back the slot reserved by ``require``; no-op if none is held."""
        if self.default_device != 'none':
            return
        if self.available_device is None:
            return

        # The log may have been deleted by reset() in the meantime; in that
        # case there is no entry left to remove.
        if os.path.exists(self.device_log):
            device_used: list = read_json(self.device_log)
            if self.available_device in device_used:
                device_used.remove(self.available_device)
            write_json(self.device_log, device_used)

        # Forget the slot so that a later require() reserves a fresh one.
        self.available_device = None

    def __enter__(self):
        """Reserve a device and return this manager for ``with ... as``."""
        self.require()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the device; exceptions from the ``with`` body propagate."""
        self.release()

    def reset(self):
        """Delete the log file, dropping every recorded reservation."""
        if not os.path.exists(self.device_log):
            return
        os.remove(self.device_log)

    def set_devices(self, devices):
        """Replace the set of slot ids that ``require`` may hand out.

        Args:
            devices: Iterable of slot ids. It is copied into a list because
                ``require`` may scan it repeatedly while polling.
        """
        self.device_total = list(devices)
